import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, DistributedSampler
from models.reward_model import RewardModel
from models.clip_utils import load_clip_model
from datasets.text_image_dataset import TextImageDataset
from train.train_reward_model import train_reward_model, validate_model, save_model, load_checkpoint

def setup(rank, world_size):
    """
    Initialise the distributed training environment for one worker process.

    Joins the default process group using the NCCL backend and makes the GPU
    whose index equals ``rank`` the current CUDA device of this process.

    Args:
        rank: Index of this process within the group; also used as the CUDA
            device index.
        world_size: Total number of processes taking part in training.
    """
    # Only provide defaults for the rendezvous endpoint: a value already set
    # by the user or a launcher (e.g. to avoid a port that is in use by
    # another job on the same machine) must not be overwritten.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)  # bind this process to its own CUDA device

def cleanup():
    """
    Tear down the default distributed process group, if one exists.

    Safe to call when no process group has been initialised (or when it has
    already been destroyed), so it can be used on error paths and called more
    than once without raising.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

def main_worker(rank, world_size, num_epochs, checkpoint_path, save_path):
    """
    Training logic executed by every spawned process.

    Sets up the process group, builds the CLIP model and a DDP-wrapped
    RewardModel, optionally resumes from ``checkpoint_path``, trains for the
    remaining epochs and saves checkpoints from rank 0.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        world_size: Total number of training processes.
        num_epochs: Epoch index at which training stops (exclusive upper
            bound of the epoch loop).
        checkpoint_path: Checkpoint file to resume from; ignored when the
            file does not exist.
        save_path: Path prefix for saved checkpoints; files are written as
            ``f"{save_path}_{epoch}.pt"``.
    """
    setup(rank, world_size)

    # Device owned by this process
    device = torch.device(f"cuda:{rank}")
    if rank == 0:
        print(f"Rank {rank} is using device {device}")

    # Load the CLIP model
    clip_model, preprocess = load_clip_model(device)

    # Build the RewardModel and wrap it in DistributedDataParallel
    reward_model = RewardModel(embed_dim=768).to(device)
    reward_model = torch.nn.parallel.DistributedDataParallel(reward_model, device_ids=[rank], output_device=rank)

    # Optimizer and learning-rate scheduler
    base_lr = 1e-4
    optimizer = torch.optim.AdamW(reward_model.parameters(), lr=base_lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    # Resume from the checkpoint on EVERY rank. Loading on rank 0 only leaves
    # the other ranks with start_epoch == 0 and a fresh optimizer state: the
    # ranks would then run a different number of epochs (deadlocking in DDP's
    # collectives once rank 0 finishes) and apply different optimizer updates.
    # NOTE(review): assumes load_checkpoint can be called from any rank, as it
    # already is from rank 0 -- confirm it does not write to disk.
    start_epoch = 0
    if os.path.exists(checkpoint_path):
        if rank == 0:
            print(f"Rank {rank}: Loading checkpoint...")
        reward_model, optimizer, start_epoch = load_checkpoint(reward_model, optimizer, checkpoint_path)
        if rank == 0:
            print(f"Rank {rank}: Loaded checkpoint from {checkpoint_path} at epoch {start_epoch}")

        # Restart from the base learning rate instead of the restored one
        for param_group in optimizer.param_groups:
            param_group['lr'] = base_lr
        if rank == 0:
            print(f"Rank {rank}: Learning rate reset to {base_lr}")

        # Rebuild the scheduler so it is attached to the restored optimizer
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    # Make rank 0's weights authoritative on all processes
    state_dict = reward_model.state_dict()
    for value in state_dict.values():
        dist.broadcast(value, src=0)

    # Guard against ranks disagreeing on where to resume: every rank must run
    # the same number of epochs or the collectives never match up.
    epoch_sync = torch.tensor([int(start_epoch)], dtype=torch.long, device=device)
    dist.broadcast(epoch_sync, src=0)
    start_epoch = int(epoch_sync.item())

    # Dataset
    train_dataset = TextImageDataset("./self_dataset/feature_poisoned_train.json", preprocess)
    # val_dataset = TextImageDataset("./small_data/val_small.json", preprocess)

    # Distributed sampler: each rank sees its own shard of the data
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
    # val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)

    # DataLoader
    train_loader = DataLoader(train_dataset, batch_size=128, sampler=train_sampler, num_workers=4, pin_memory=True)
    # val_loader = DataLoader(val_dataset, batch_size=64, sampler=val_sampler, num_workers=4, pin_memory=True)

    # Create the output directory up front so a missing directory is not
    # first discovered when the epoch-10 checkpoint is written.
    if rank == 0:
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

    # Training loop
    for epoch in range(start_epoch, num_epochs):
        if rank == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f"Rank {rank}: Epoch {epoch + 1}/{num_epochs}, Learning rate: {current_lr}")

        # Re-seed the sampler so each epoch uses a different shuffle
        train_sampler.set_epoch(epoch)

        # Train for one epoch
        avg_loss = train_reward_model(reward_model, train_loader, optimizer, clip_model, device, epoch)

        ## Validation (results printed on rank 0 only)
        # if rank == 0:
            # validate_model(reward_model, val_loader, clip_model, device)

        scheduler.step()

        # Periodic checkpoint every 10 epochs (rank 0 only)
        if rank == 0 and (epoch + 1) % 10 == 0:
            current_save_path = f"{save_path}_{epoch + 1}.pt"
            save_model(reward_model, optimizer, epoch + 1, current_save_path)
            print(f"Rank {rank}: Model saved at epoch {epoch + 1} to {current_save_path}")

    # Final checkpoint (rank 0 only)
    if rank == 0:
        final_save_path = f"{save_path}_{num_epochs}.pt"
        save_model(reward_model, optimizer, num_epochs, final_save_path)
        print(f"Rank {rank}: Final model saved at {final_save_path}")

    cleanup()


def main():
    """
    Configure the run and launch one training process per GPU.

    Restricts the visible CUDA devices to GPUs 6 and 7, then spawns two
    ``main_worker`` processes that resume from the clean checkpoint and
    train up to epoch 20, blocking until all of them have exited.
    """
    # Must be set before the workers are spawned so they inherit it.
    os.environ["CUDA_VISIBLE_DEVICES"] = "6,7"

    n_procs = 2
    total_epochs = 20
    resume_from = "./checkpoints/clean_RM_10.pt"
    output_prefix = "./checkpoints/poisoned_RM"

    # mp.spawn passes the process index as the first argument (rank);
    # the remaining positional arguments come from this tuple.
    worker_args = (n_procs, total_epochs, resume_from, output_prefix)
    mp.spawn(main_worker, args=worker_args, nprocs=n_procs, join=True)


# Script entry point: start training only when this file is run directly.
# NOTE(review): mp.spawn presumably re-imports this module in the child
# processes, so the launch has to stay behind this guard -- confirm.
if __name__ == "__main__":
    main()
